import torch
from torch import nn
import torch.nn.functional as F


class TransformerEncoderLayer(nn.Module):
    """Stack of ``num_layers`` nn.TransformerEncoderLayer blocks.

    Note: despite the name, this wraps a full nn.TransformerEncoder
    (a stack of layers), not a single encoder layer.
    """

    def __init__(self, input_size, hidden_size, num_layers=1, num_heads=8, dropout=0.1):
        super().__init__()
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=input_size,           # embedding dimension per position
                nhead=num_heads,              # must divide d_model evenly
                dim_feedforward=hidden_size,
                dropout=dropout
            ),
            num_layers=num_layers
        )

    def forward(self, x):
        # With PyTorch's default batch_first=False, x is expected to be
        # (seq_len, batch, d_model), where d_model == input_size.
        return self.transformer_encoder(x)
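

# --- Example usage (a sketch under assumed shapes; not part of the original
# model code). The wrapper above keeps the default batch_first=False layout,
# so inputs are (seq_len, batch, d_model); a quick shape check:
def _encoder_shape_check():
    enc = TransformerEncoderLayer(input_size=128, hidden_size=512, num_layers=2)
    x = torch.randn(10, 4, 128)  # (seq_len=10, batch=4, d_model=128)
    assert enc(x).shape == (10, 4, 128)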


class gotmdmodel(nn.Module):
    # Defaults for the bookkeeping values filled in by set_val().
    g_epoch = 0
    g_batch = 0

    def set_val(self, n_epoch, n_batch):
        self.g_epoch = n_epoch
        self.g_batch = n_batch

    def __init__(self):
        super().__init__()
        # Feature flattening (keeps dim 0, flattens the rest)
        self.flatten9 = nn.Flatten()
        # Transformer encoder stack
        self.transformer_encoder = TransformerEncoderLayer(input_size=128, hidden_size=512, num_layers=6)
        # Fully connected classifier head
        self.L10 = nn.Linear(12, 128)
        self.L11 = nn.Linear(128, 64)
        self.L12 = nn.Linear(64, 4)

    def forward(self, x):
        # The layer sizes imply x must flatten to (128, 12); see the
        # usage sketch at the bottom of the file.
        x = self.flatten9(x)             # (128, 12)
        x = x.unsqueeze(0)               # add a leading length-1 sequence dim: (1, 128, 12)
        x = x.permute(0, 2, 1)           # (1, 12, 128): seq-first layout, d_model=128
        # Note the 12 flattened features act as the transformer "batch" here.
        x = self.transformer_encoder(x)  # (1, 12, 128)
        x = x.squeeze(0)                 # drop the sequence dim: (12, 128)
        x = x.permute(1, 0)              # back to (128, 12)
        x = self.flatten9(x)
        # Three fully connected layers
        x = self.L10(x)
        x = self.L11(x)
        x = self.L12(x)
        # Note: if training with nn.CrossEntropyLoss, feed it the raw logits
        # instead; that loss applies log-softmax internally.
        x = F.softmax(x, dim=1)
        return x
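

# --- Minimal end-to-end sketch (assumed input shape; not original code) ---
# The layer sizes above only type-check when the input flattens to (128, 12):
# L10 expects 12 input features and the encoder's d_model is 128.
if __name__ == "__main__":
    _encoder_shape_check()
    model = gotmdmodel()
    model.set_val(n_epoch=0, n_batch=0)
    model.eval()  # disable dropout for the demo forward pass
    x = torch.randn(128, 12)
    probs = model(x)
    print(probs.shape)  # torch.Size([128, 4])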
